# -*- coding: utf-8 -*-
# @Time : 2020/9/10 下午9:50
# @Author : rebeater
# @File : loosely_couple_core.py
# @Project: LooselyCouple_2020
# @Function: Loosely Coupled Algorithm Core Functions


import numpy as np
import prettytable
import yaml
import copy

import ins_core as ins
# import ins_core as ins
from ins_core import _deg, _mGal, _ppm, _hour, _sqrt_hour, ImuPara, NavData, WGS84

# GNSS input-file format codes. The code is chosen by the ``gnss-format`` YAML
# option (Option.gnss_format) and dispatched on in load_gnss, which currently
# handles GNSS_TXT_POS_7 and GNSS_TXT_POS_14 only.
(
    GNSS_TXT_POS_7,
    GNSS_BIN_POS_7,
    GNSS_TXT_POS_14,
    GNSS_BIN_POS_14,
    RTKLIB_TXT_POS,
    GNSS_TXT_POS_VEL,
    RESERVED,
) = range(7)


class Option:
    """
    Run configuration for loosely coupled INS/GNSS processing, read from a YAML file.

    Keys read from the file: start-time, end-time, imu-path, output-path, gnss-path,
    imu-parameter-cfg, imu-data-rate, gnss-data-rate, alignment-mode, alignment-epoch,
    init-pos-std, init-vel-std, init-atti-std (degrees), gnss-format, antenna-level-arm.
    """

    def __init__(self, yaml_path):
        """
        Load the options from a YAML file.

        :param yaml_path: path of the YAML configuration file
        :raises KeyError: if one of the required keys is missing from the file
        """
        with open(yaml_path) as f:
            args = yaml.safe_load(f)
        print(args)

        self.start_time = args["start-time"]  # GPS second
        self.end_time = args["end-time"]  # GPS second; -1 keeps data to the end (see split_data)
        self.imu_filepath = args["imu-path"]
        self.output_filepath = args["output-path"]
        self.gnss_path = args["gnss-path"]
        self.imu_parameter_cfg = args["imu-parameter-cfg"]
        self.imu_data_rate = args["imu-data-rate"]  # Hz
        self.gnss_data_rate = args["gnss-data-rate"]  # Hz
        self.alignment_mode = args["alignment-mode"]
        self.alignment_epoch = NavData(args["alignment-epoch"])
        self.init_pos_std = np.array(args["init-pos-std"], dtype=np.double)  # m
        self.init_vel_std = np.array(args["init-vel-std"], dtype=np.double)  # m/s
        # Configured in degrees, stored in radians.
        self.init_atti_std = np.deg2rad(np.array(args["init-atti-std"], dtype=np.double))
        # The aiding sources below are not implemented yet and are forced off.
        self.nhc_enable = False  # TODO: unfinished; should read args["nhc-enable"]
        self.nhc_hz = 0
        self.zupt_enable = False  # TODO: unfinished
        self.zupt_hz = 0
        self.odo_enable = False
        self.odo_hz = 0
        self.odo_wheel_radius = 0
        # GNSS file format code: GNSS_TXT_POS_7 = 0, GNSS_BIN_POS_7 = 1, GNSS_TXT_POS_14 = 2,
        # GNSS_BIN_POS_14 = 3, RTKLIB_TXT_POS = 4, GNSS_TXT_POS_VEL = 5, RESERVED = 6
        self.gnss_format = args["gnss-format"]
        self.antenna_level_arm = np.array(args["antenna-level-arm"], dtype=np.double)  # m, body frame

        self.imu_para = ImuPara(self.imu_parameter_cfg)

    @staticmethod
    def _basename(path):
        """Return the last component of *path*, accepting both '/' and '\\' separators."""
        return str(path).replace("\\", "/").split("/")[-1]

    def _config_rows(self):
        """Return [option, value, unit] rows describing the run configuration."""
        return [
            ["start-time", self.start_time, "GPS second"],
            ["end-time", self.end_time, "GPS second"],
            ["imu-filepath", self._basename(self.imu_filepath), ""],
            ["gnss-filepath", self._basename(self.gnss_path), ""],
            ["output-filepath", self._basename(self.output_filepath), ""],
            ["imu-parameter-cfg", self._basename(self.imu_parameter_cfg), ""],
            ["imu-data-rate", self.imu_data_rate, "Hz"],
            ["gnss-data-rate", self.gnss_data_rate, "Hz"],
            ["alignment-mode", self.alignment_mode, ""],
            ["alignment-epoch", self.alignment_epoch.time_s, ""],
            ["init_pos_std", self.init_pos_std, "m"],
            ["init_vel_std", self.init_vel_std, "m/s"],
            ["init_atti_std", np.rad2deg(self.init_atti_std), "deg"],
            ["antenna-level-arm", self.antenna_level_arm, "m"],
            ["NHC enable", self.nhc_enable, ""],
            ["ZUPT enable", self.zupt_enable, ""],
            ["Odometer enable", self.odo_enable, ""],
        ]

    def _imu_rows(self):
        """Return [option, value, unit] rows for the IMU error model, converted to display units."""
        para = self.imu_para
        return [
            ["arw", para.arw / _deg * _sqrt_hour, "deg/sqrt(h)"],
            ["vrw", para.vrw * _sqrt_hour, "m/s/sqrt(h)"],
            ["gb-std", para.gb_std / _deg * _hour, "deg/h"],
            ["ab-std", para.ab_std / _mGal, "mGal"],
            ["gs-std", para.gs_std / _ppm, "ppm"],
            ["as-std", para.as_std / _ppm, "ppm"],
            ["gb-ini", para.gb_ini / _deg * _hour, "deg/h"],
            ["ab-ini", para.ab_ini / _mGal, "mGal"],
            ["gs-ini", para.gs_ini / _ppm, "ppm"],
            ["as-ini", para.as_ini / _ppm, "ppm"],
            ["gt_corr", para.gt_corr / _hour, "h"],
            ["at_corr", para.at_corr / _hour, "h"],
        ]

    def _table(self):
        """Build a PrettyTable with all configuration and IMU rows."""
        table = prettytable.PrettyTable()
        table.field_names = ["option", "value", "unit"]
        for row in self._config_rows() + self._imu_rows():
            table.add_row(row)
        return table

    def show(self):
        """Print the option table to stdout."""
        print(self._table())

    def __str__(self):
        """Return the option table as text (the same text that ``show`` prints)."""
        return str(self._table())


class KalmanFilter:
    """
    15-state error-state Kalman filter for loosely coupled INS/GNSS integration.

    Layout of the error state ``xd`` (and of the rows/columns of ``mat_p``):
        0:3    position error (north, east, down), m
        3:6    velocity error, m/s
        6:9    attitude error, rad
        9:12   gyroscope bias error
        12:15  accelerometer bias error
    After every successful measurement update the error state is fed back into the
    navigation solution and reset to zero (see ``output``).
    """

    def __init__(self, opt: Option):
        """
        Initialise the covariance matrix P and the continuous process-noise matrix Q0.

        :param opt: run configuration. Provides the initial standard deviations, the IMU
                    error model, the IMU data rate and the antenna lever arm.
        """
        self.state_cnt = 15  # 15-dimensional error state
        imu_para = opt.imu_para
        # P0 = diag(std)^2
        self.mat_p = np.diag([
            opt.init_pos_std[0], opt.init_pos_std[1], opt.init_pos_std[2],  # initial position std
            opt.init_vel_std[0], opt.init_vel_std[1], opt.init_vel_std[2],  # initial velocity std
            opt.init_atti_std[0], opt.init_atti_std[1], opt.init_atti_std[2],  # initial attitude std
            imu_para.gb_std[0], imu_para.gb_std[1], imu_para.gb_std[2],  # gyro bias std
            imu_para.ab_std[0], imu_para.ab_std[1], imu_para.ab_std[2],  # accel bias std
        ]) ** 2
        # Position has no direct noise. Velocity/attitude are driven by VRW/ARW. The bias terms use
        # 2*sigma^2/T, which matches the (1 - dt/T) bias decay in __mat_phi.
        self.mat_q0 = np.diag([
            0, 0, 0,
            imu_para.vrw ** 2, imu_para.vrw ** 2, imu_para.vrw ** 2,
            imu_para.arw ** 2, imu_para.arw ** 2, imu_para.arw ** 2,
            2 * (imu_para.gb_std[0] ** 2) / imu_para.gt_corr,
            2 * (imu_para.gb_std[1] ** 2) / imu_para.gt_corr,
            2 * (imu_para.gb_std[2] ** 2) / imu_para.gt_corr,
            2 * (imu_para.ab_std[0] ** 2) / imu_para.at_corr,
            2 * (imu_para.ab_std[1] ** 2) / imu_para.at_corr,
            2 * (imu_para.ab_std[2] ** 2) / imu_para.at_corr,
        ])
        self.data_rate = opt.imu_data_rate
        self.antenna_level_arm = opt.antenna_level_arm
        self.at_corr = imu_para.at_corr
        self.gt_corr = imu_para.gt_corr
        self.xd = np.zeros(self.state_cnt, dtype=np.double)

        self.wgs84 = WGS84()

    def predict(self, last_epoch: NavData):
        """
        Propagate the error state and its covariance over one IMU interval.

        :param last_epoch: navigation state after the latest mechanization step (used to build Phi)
        """
        # Fixed step derived from the configured IMU rate, not from the time-stamp difference.
        delta_t = 1.0 / self.data_rate
        mat_phi = self.__mat_phi(last_epoch, delta_t)
        # First-order approximation of the discrete process noise.
        mat_q = 0.5 * (mat_phi @ self.mat_q0 + self.mat_q0 @ mat_phi.T) * delta_t
        self.xd = mat_phi @ self.xd
        self.mat_p = mat_phi @ self.mat_p @ mat_phi.T + mat_q

    def measure_update(self, cur_gnss, last_epoch):
        """
        Run the Kalman measurement update with one GNSS position fix, then apply feedback correction.

        :param cur_gnss: [0] GNSS second, [1] lat (rad), [2] lon (rad), [3] alt (m),
                         [4] north std (m), [5] east std (m), [6] down std (m)
        :param last_epoch: navigation state at the GNSS epoch
        :return: the corrected navigation state, or ``last_epoch`` unchanged if the
                 innovation covariance is not positive definite
        """
        mat_r = np.diag(cur_gnss[4:7]) ** 2
        mat_h = self.__mat_h(last_epoch)
        mat_s = mat_h @ self.mat_p @ mat_h.T + mat_r  # innovation covariance H P H^T + R
        try:
            # S = L L^T. Inverting the upper factor U = L^T gives S^-1 = U^-1 U^-T.
            mat_u_inv = np.linalg.inv(np.linalg.cholesky(mat_s).T)
        except np.linalg.LinAlgError as e:
            print("\nH_P_H_R is not positive definite!!!")
            print(e.args)
            return last_epoch
        mat_s_inv = mat_u_inv @ mat_u_inv.T
        mat_k = self.mat_p @ mat_h.T @ mat_s_inv  # 15x3 gain
        z = self.__zr(cur_gnss[1:4], last_epoch)
        self.xd = self.xd + mat_k @ (z - mat_h @ self.xd)
        # Joseph-form covariance update, then re-symmetrise against round-off.
        mat_i_k_h = np.eye(self.state_cnt, dtype=np.double) - mat_k @ mat_h
        self.mat_p = mat_i_k_h @ self.mat_p @ mat_i_k_h.T + mat_k @ mat_r @ mat_k.T
        self.mat_p = (self.mat_p + self.mat_p.T) / 2.0
        return self.output(last_epoch)

    def output(self, last_epoch):
        """
        Feed the estimated error state back into the navigation solution and reset ``xd``.

        ``last_epoch`` is modified in place and also returned.

        :param last_epoch: navigation state to correct
        :return: the corrected navigation state (the same object as ``last_epoch``)
        """
        current_epoch = last_epoch
        latitude = last_epoch.latitude
        height = last_epoch.height + self.xd[2]  # xd[2] is the down error, so it adds to height
        dvn = self.xd[3:6]

        # Angle between the computed and the true navigation frame, caused by the horizontal
        # position error.
        d_atti = np.array([
            self.xd[1] / (self.wgs84.RN(latitude) + height),
            -self.xd[0] / (self.wgs84.RM(latitude) + height),
            -self.xd[1] * np.tan(latitude) / (self.wgs84.RN(latitude) + height)
        ], dtype=np.double)

        # Position correction
        qne = last_epoch.Qne
        qnc = ins.rv_to_quaternion(-d_atti)
        qne = qne * qnc
        new_latitude, new_longitude = ins.qne_to_lla(qne.normalize())

        # Velocity correction
        ccn = np.eye(3, dtype=np.double) + ins.skew_sym(d_atti)
        vn = np.matmul(ccn, last_epoch.Vn - dvn)

        # Attitude correction (quaternion first, then DCM, then Euler angles)
        phi = self.xd[6:9] + d_atti
        qpn = ins.rv_to_quaternion(phi)
        qbn = (qpn * last_epoch.Qbn).normalize()
        current_dcm_b_n = ins.quaternion_to_dcm(qbn)
        euler_angle = ins.dcm_to_euler(current_dcm_b_n)

        current_epoch.roll = euler_angle[0]
        current_epoch.pitch = euler_angle[1]
        current_epoch.heading = euler_angle[2]
        current_epoch.atti = euler_angle
        current_epoch.v_f_k_b = last_epoch.v_f_k_b
        current_epoch.Cbn = current_dcm_b_n

        current_epoch.latitude = new_latitude
        current_epoch.longitude = new_longitude
        current_epoch.height = height
        current_epoch.Qne = qne
        current_epoch.Qbn = qbn
        current_epoch.vn = vn[0]
        current_epoch.ve = vn[1]
        current_epoch.vg = vn[2]
        current_epoch.Vn = vn
        current_epoch.dvn = last_epoch.dvn

        # Accumulate the estimated bias errors. Rebinding (rather than ``+=``) avoids mutating
        # an array the caller may share, e.g. the ImuPara initial-bias arrays.
        current_epoch.gyro_bias = current_epoch.gyro_bias + self.xd[9:12]
        current_epoch.acce_bias = current_epoch.acce_bias + self.xd[12:15]
        # Essential step: the error state has been fed back, so reset it to zero.
        self.xd = np.zeros(self.state_cnt, dtype=np.double)
        return current_epoch

    def __mat_phi(self, last_epoch: NavData, delta_t):
        """
        Build the discrete state-transition matrix Phi ~= I + F*dt for one IMU interval.

        Some papers build Phi inside the mechanization and multiply it up epoch by epoch.
        Following Shin E.H. (2005), this code instead recomputes Phi from the latest
        navigation state at each predict call.

        :param last_epoch: current navigation state (latitude, height, Vn, Cbn, v_f_k_b)
        :param delta_t: propagation interval, s
        :return: 15x15 transition matrix
        """
        latitude = last_epoch.latitude
        height = last_epoch.height
        vn = last_epoch.Vn
        cbn = last_epoch.Cbn
        wgs84 = self.wgs84
        i3 = np.eye(3, dtype=np.double)

        omega_en_n = wgs84.omega_en_n(vn[0], vn[1], height, latitude)
        omega_ie_n = wgs84.omega_ie_n(latitude)
        rm = wgs84.RM(latitude)
        rn = wgs84.RN(latitude)
        g = wgs84.g(latitude, height)

        mat_phi = np.zeros([self.state_cnt, self.state_cnt], dtype=np.double)
        # Position-error rows
        mat_phi[0:3, 0:3] = i3 + ins.skew_sym(-omega_en_n) * delta_t
        mat_phi[0:3, 3:6] = i3 * delta_t
        # Velocity-error rows
        mat_phi[3:6, 0:3] = np.diag([
            -g / (rm + height), -g / (rn + height), 2 * g / (np.sqrt(rm * rn) + height)
        ]) * delta_t
        mat_phi[3:6, 3:6] = i3 - ins.skew_sym(2 * omega_ie_n + omega_en_n) * delta_t
        mat_phi[3:6, 6:9] = ins.skew_sym(last_epoch.v_f_k_b) * delta_t
        mat_phi[3:6, 12:15] = cbn * delta_t
        # Attitude-error rows
        mat_phi[6:9, 6:9] = i3 - ins.skew_sym(omega_en_n + omega_ie_n) * delta_t
        mat_phi[6:9, 9:12] = -cbn * delta_t
        # Bias-error rows: each bias decays with its own correlation time.
        mat_phi[9:12, 9:12] = i3 * (1.0 - delta_t / self.gt_corr)
        mat_phi[12:15, 12:15] = i3 * (1.0 - delta_t / self.at_corr)
        return mat_phi

    def __mat_h(self, last_epoch):
        """
        Build the 3x15 position design matrix H.

        Position error maps directly. Attitude error enters through the lever-arm term
        skew(Cbn * l_antenna).
        """
        cbn = last_epoch.Cbn
        hr = np.zeros([3, self.state_cnt], dtype=np.double)
        hr[0:3, 0:3] = np.eye(3, dtype=np.double)
        hr[0:3, 6:9] = ins.skew_sym(cbn @ self.antenna_level_arm)
        return hr

    def __zr(self, gnss_pos, last_epoch: ins.NavData):
        """
        Compute the position innovation z = (INS antenna position - GNSS position), in N/E/D.

        :param gnss_pos: [lat (rad), lon (rad), alt (m)] of the GNSS fix
        :param last_epoch: INS navigation state at the same epoch
        :return: length-3 measurement vector, m
        """
        latitude = gnss_pos[0]
        longitude = gnss_pos[1]
        # Columns are the north, east and down unit vectors expressed in ECEF.
        cne = np.array([
            [-np.sin(latitude) * np.cos(longitude), -np.sin(longitude), -np.cos(latitude) * np.cos(longitude)],
            [-np.sin(latitude) * np.sin(longitude), np.cos(longitude), -np.cos(latitude) * np.sin(longitude)],
            [np.cos(latitude), 0.0, -np.sin(latitude)]
        ], dtype=np.double)
        re = self.wgs84.lla_to_ecef(np.array([last_epoch.latitude, last_epoch.longitude, last_epoch.height]))
        re_m = self.wgs84.lla_to_ecef(gnss_pos)
        de = re - re_m  # ECEF position difference
        cbn = last_epoch.Cbn
        z = cne.transpose() @ de + cbn @ self.antenna_level_arm
        return z


def split_data(data_raw, start_time, end_time, index, f: int):
    """
    Cut a time-ordered 2-D data array to the processing interval.

    The slice starts at the first row whose time stamp is >= ``start_time``. When
    ``end_time`` is not -1, it keeps ``f * (end_time - start_time) + 1`` rows, which
    assumes gap-free data at rate ``f``. Otherwise it keeps every remaining row,
    including the last one.

    :param data_raw: raw data, one record per row
    :param start_time: start time
    :param end_time: end time, or -1 to keep data up to the end of the array
    :param index: column index of the time stamp
    :param f: data rate in Hz
    :return: slice (view) of ``data_raw``
    :raises ValueError: if no row has a time stamp >= ``start_time``
    """
    times = data_raw[:, index]
    reached = times >= start_time
    if not np.any(reached):
        raise ValueError("no data matches start")
    bias = int(np.argmax(reached))  # first row at or after start_time
    if end_time != -1:
        # round/int so that float times or rates still give a valid slice bound
        row_cnt = int(round(f * (end_time - start_time))) + 1
        return data_raw[bias:bias + row_cnt, :]
    return data_raw[bias:, :]


def load_gnss(gnss_path: str, gnss_format: int):
    """
    Load GNSS position data from a whitespace-separated text file.

    Latitude and longitude (columns 1-2) are converted from degrees to radians. For
    GNSS_TXT_POS_14, columns 7:10 are copied into the standard-deviation columns 4:7.

    :param gnss_path: file path
    :param gnss_format: data format code (GNSS_TXT_POS_7 or GNSS_TXT_POS_14)
    :return: 2-D array whose first columns are gps_t/s lat/rad lon/rad h/m n-std/m e-std/m d-std/m
    :raises ValueError: if ``gnss_format`` is not supported
    """
    if gnss_format not in (GNSS_TXT_POS_7, GNSS_TXT_POS_14):
        raise ValueError("gnss_format %d is not support yet" % gnss_format)
    # ndmin=2 keeps a single-epoch file 2-D so the column slicing below still works
    gnss_data = np.loadtxt(gnss_path, ndmin=2)
    gnss_data[:, 1:3] = np.deg2rad(gnss_data[:, 1:3])
    if gnss_format == GNSS_TXT_POS_14:
        gnss_data[:, 4:7] = gnss_data[:, 7:10]
    return gnss_data


def post_loosely_couple(opt: Option, signal=None):
    """
    Post-processing entry point.

    Runs the INS mechanization at every IMU epoch, applies a Kalman measurement update
    whenever a GNSS epoch matches the IMU time, and writes every epoch to
    ``opt.output_filepath``.

    :param opt: run configuration
    :param signal: optional progress sink with an ``emit(percent)`` method. When None,
                   ``ins.progress`` is used instead.
    :return: None
    """
    opt.show()

    imu_data = np.loadtxt(opt.imu_filepath)
    imu_data = split_data(imu_data, opt.start_time, opt.end_time, 0, opt.imu_data_rate)
    gnss_data = load_gnss(opt.gnss_path, opt.gnss_format)
    gnss_data = split_data(gnss_data, opt.start_time, opt.end_time, 0, opt.gnss_data_rate)

    kf = KalmanFilter(opt)
    epoch_cnt = imu_data.shape[0]
    gnss_cnt = gnss_data.shape[0]
    match_tol = 1.0 / opt.imu_data_rate
    # Work on copies so that a repeated run with the same Option starts from the same state,
    # and so that bias feedback cannot write back into opt.imu_para.
    cur_epoch = copy.deepcopy(opt.alignment_epoch)
    cur_epoch.acce_bias = np.array(opt.imu_para.ab_ini, dtype=np.double)
    cur_epoch.gyro_bias = np.array(opt.imu_para.gb_ini, dtype=np.double)
    idx_gps = 1
    with open(opt.output_filepath, 'w') as output:
        for idx_imu in range(1, epoch_cnt):
            imu_time = imu_data[idx_imu, 0]
            # Mechanization
            cur_epoch = ins.mechanization_n(imu_data[idx_imu], imu_data[idx_imu - 1], cur_epoch)
            # Kalman prediction
            kf.predict(cur_epoch)
            # Skip GNSS epochs that can no longer match (e.g. they fell into an IMU gap).
            # Otherwise a single miss would block every later update.
            while idx_gps < gnss_cnt and gnss_data[idx_gps, 0] <= imu_time - match_tol:
                idx_gps += 1
            if idx_gps < gnss_cnt and np.fabs(gnss_data[idx_gps, 0] - imu_time) < match_tol:
                # Measurement update
                cur_epoch = kf.measure_update(gnss_data[idx_gps], cur_epoch)
                idx_gps += 1
            ins.write2txt(output, cur_epoch)
            if signal is None:
                ins.progress(100 * idx_imu / epoch_cnt)
            else:
                signal.emit(100 * idx_imu / epoch_cnt)
    print("\nfinished")


def outage_evalution(opt: Option, start_time=1000, step=10, time_outage=100):
    """
    Simulated GNSS-outage evaluation.

    Runs the normal loosely coupled filter. Every ``step`` seconds, starting ``start_time``
    seconds after ``opt.start_time``, it branches off a copy of the current solution and
    propagates that copy with INS mechanization only, for ``time_outage`` seconds. The
    final state of each branch is appended to ``<imu file>.<time_outage>s.otg``. The full
    filtered solution is written to ``opt.output_filepath``.

    :param opt: run configuration
    :param start_time: offset (s) from opt.start_time of the first simulated outage
    :param step: interval (s) between the starts of two consecutive outages
    :param time_outage: duration (s) of each outage
    :return: None
    """
    opt.show()

    rate = opt.imu_data_rate
    time_start_outage = opt.start_time + start_time
    outage_epochs = int(round(time_outage * rate))  # mechanization steps per outage
    # An outage may only start while at least (time_outage + 1) s of IMU data remain.
    tail_epochs = int(round((1 + time_outage) * rate))

    imu_data = np.loadtxt(opt.imu_filepath)
    imu_data = split_data(imu_data, opt.start_time, opt.end_time, 0, rate)
    gnss_data = load_gnss(opt.gnss_path, opt.gnss_format)
    gnss_data = split_data(gnss_data, opt.start_time, opt.end_time, 0, opt.gnss_data_rate)

    kf = KalmanFilter(opt)
    epoch_cnt = imu_data.shape[0]
    gnss_cnt = gnss_data.shape[0]
    match_tol = 1.0 / rate
    # Data too short for even one outage: no outage starts.
    latest_outage_time = imu_data[-tail_epochs, 0] if 0 < tail_epochs <= epoch_cnt else -np.inf

    # Work on copies so that repeated runs and bias feedback cannot modify opt.
    cur_epoch = copy.deepcopy(opt.alignment_epoch)
    cur_epoch.acce_bias = np.array(opt.imu_para.ab_ini, dtype=np.double)
    cur_epoch.gyro_bias = np.array(opt.imu_para.gb_ini, dtype=np.double)
    idx_gps = 1

    outage_path = opt.imu_filepath + ".%ds.otg" % time_outage
    with open(outage_path, 'w') as outage_file, open(opt.output_filepath, 'w') as output:
        for idx_imu in range(1, epoch_cnt):
            imu_time = imu_data[idx_imu, 0]
            # Mechanization
            cur_epoch = ins.mechanization_n(imu_data[idx_imu], imu_data[idx_imu - 1], cur_epoch)
            # Kalman prediction
            kf.predict(cur_epoch)

            # Outage evaluation: from this epoch, propagate an INS-only copy for time_outage seconds.
            if time_start_outage < imu_time < latest_outage_time:
                # Deep copy so the INS-only branch cannot share mutable state with the filtered solution.
                outage_epoch = copy.deepcopy(cur_epoch)
                for k in range(idx_imu + 1, idx_imu + outage_epochs + 1):
                    outage_epoch = ins.mechanization_n(imu_data[k], imu_data[k - 1], outage_epoch)
                ins.write2txt(outage_file, outage_epoch)
                time_start_outage += step

            # Skip GNSS epochs that can no longer match (e.g. they fell into an IMU gap).
            while idx_gps < gnss_cnt and gnss_data[idx_gps, 0] <= imu_time - match_tol:
                idx_gps += 1
            if idx_gps < gnss_cnt and np.fabs(gnss_data[idx_gps, 0] - imu_time) < match_tol:
                # Measurement update
                cur_epoch = kf.measure_update(gnss_data[idx_gps], cur_epoch)
                idx_gps += 1
            ins.write2txt(output, cur_epoch)
            ins.progress(100 * idx_imu / epoch_cnt)
    print("\nfinished")
